// Copyright (c) 2005 Brian Wellington (bwelling@xbill.org)

package org.xbill.DNS;

import java.io.*;
import java.net.*;
import java.nio.*;
import java.nio.channels.*;

/**
 * Client that exchanges length-prefixed messages over a TCP
 * {@link SocketChannel}.
 *
 * <p>Every message written by {@link #send(byte[])} is preceded by a two-byte,
 * big-endian length, and {@link #recv()} expects the same framing on input.
 *
 * <p>The selection key ({@code key}), the deadline ({@code endTime}) and the
 * helpers {@code blockUntil}, {@code verboseLog} and {@code cleanup} are
 * inherited from {@code Client}, which is not part of this file. NOTE(review):
 * the loops below poll and retry, so they presumably rely on {@code Client}
 * putting the channel into non-blocking mode and registering it with a
 * selector -- confirm against {@code Client}.
 */
final class TCPClient extends Client {

	/** Largest payload whose size can be encoded in the two-byte length prefix. */
	private static final int MAX_MESSAGE_LENGTH = 0xFFFF;

	/**
	 * Opens a new socket channel and hands it to the superclass.
	 *
	 * @param endTime deadline for operations on this client; it is compared
	 *        against {@link System#currentTimeMillis()} when sending and
	 *        receiving
	 * @throws IOException if the channel cannot be opened or the superclass
	 *         constructor fails
	 */
	public TCPClient(long endTime) throws IOException {
		super(SocketChannel.open(), endTime);
	}

	/**
	 * Binds the underlying socket to a local address.
	 *
	 * @param addr the local address to bind to
	 * @throws IOException if the bind fails
	 */
	void bind(SocketAddress addr) throws IOException {
		SocketChannel channel = (SocketChannel) key.channel();
		channel.socket().bind(addr);
	}

	/**
	 * Connects the channel to a remote address.
	 *
	 * <p>If the connection is not established immediately, interest in
	 * {@link SelectionKey#OP_CONNECT} is registered and
	 * {@link SocketChannel#finishConnect()} is retried until it succeeds,
	 * waiting through {@code blockUntil} whenever the key does not report
	 * itself connectable. The key's interest set is reset to 0 on exit as
	 * long as the key is still valid.
	 *
	 * @param addr the remote address to connect to
	 * @throws IOException if the connection attempt fails, or if
	 *         {@code blockUntil} throws (presumably on timeout -- confirm
	 *         against {@code Client})
	 */
	void connect(SocketAddress addr) throws IOException {
		SocketChannel channel = (SocketChannel) key.channel();
		if (channel.connect(addr))
			return;
		key.interestOps(SelectionKey.OP_CONNECT);
		try {
			while (!channel.finishConnect()) {
				if (!key.isConnectable())
					blockUntil(key, endTime);
			}
		} finally {
			if (key.isValid())
				key.interestOps(0);
		}
	}

	/**
	 * Writes one message, preceded by its two-byte big-endian length.
	 *
	 * @param data the message payload; at most 65535 bytes
	 * @throws IllegalArgumentException if {@code data} is longer than 65535
	 *         bytes and therefore cannot be described by the length prefix
	 * @throws EOFException if the channel reports a negative write count
	 * @throws SocketTimeoutException if the deadline has passed while part of
	 *         the message is still unwritten
	 * @throws IOException if writing fails or {@code blockUntil} throws
	 */
	void send(byte[] data) throws IOException {
		// Without this check the prefix would carry only the low 16 bits of
		// the length while the whole payload was still written, leaving the
		// peer unable to find the message boundary.
		if (data.length > MAX_MESSAGE_LENGTH)
			throw new IllegalArgumentException("TCP message too long: "
					+ data.length + " bytes (maximum "
					+ MAX_MESSAGE_LENGTH + ")");
		SocketChannel channel = (SocketChannel) key.channel();
		verboseLog("TCP write", data);
		byte[] lengthArray = new byte[2];
		lengthArray[0] = (byte) (data.length >>> 8);
		lengthArray[1] = (byte) (data.length & 0xFF);
		// Prefix and payload are passed together as one gathering write.
		ByteBuffer[] buffers = new ByteBuffer[2];
		buffers[0] = ByteBuffer.wrap(lengthArray);
		buffers[1] = ByteBuffer.wrap(data);
		int total = data.length + 2;
		int nsent = 0;
		key.interestOps(SelectionKey.OP_WRITE);
		try {
			while (nsent < total) {
				if (key.isWritable()) {
					long n = channel.write(buffers);
					if (n < 0)
						throw new EOFException();
					nsent += (int) n;
					// The deadline only matters while data is outstanding.
					if (nsent < total
							&& System.currentTimeMillis() > endTime)
						throw new SocketTimeoutException();
				} else {
					blockUntil(key, endTime);
				}
			}
		} finally {
			if (key.isValid())
				key.interestOps(0);
		}
	}

	/**
	 * Reads exactly {@code length} bytes from the channel.
	 *
	 * @param length number of bytes to read; 0 yields an empty array without
	 *        touching the channel
	 * @return a new array holding the bytes read
	 * @throws EOFException if the channel reaches end-of-stream before
	 *         {@code length} bytes have arrived
	 * @throws SocketTimeoutException if the deadline has passed while bytes
	 *         are still missing
	 * @throws IOException if reading fails or {@code blockUntil} throws
	 */
	private byte[] _recv(int length) throws IOException {
		SocketChannel channel = (SocketChannel) key.channel();
		int nrecvd = 0;
		byte[] data = new byte[length];
		// The buffer wraps the result array, so reads fill it directly.
		ByteBuffer buffer = ByteBuffer.wrap(data);
		key.interestOps(SelectionKey.OP_READ);
		try {
			while (nrecvd < length) {
				if (key.isReadable()) {
					long n = channel.read(buffer);
					if (n < 0)
						throw new EOFException();
					nrecvd += (int) n;
					// The deadline only matters while bytes are missing.
					if (nrecvd < length
							&& System.currentTimeMillis() > endTime)
						throw new SocketTimeoutException();
				} else {
					blockUntil(key, endTime);
				}
			}
		} finally {
			if (key.isValid())
				key.interestOps(0);
		}
		return data;
	}

	/**
	 * Reads one length-prefixed message: a two-byte big-endian length
	 * followed by that many payload bytes.
	 *
	 * @return the payload, without the length prefix
	 * @throws IOException if either read fails; see {@link #_recv(int)}
	 */
	byte[] recv() throws IOException {
		byte[] buf = _recv(2);
		// Mask each byte so it is treated as unsigned when combined.
		int length = ((buf[0] & 0xFF) << 8) + (buf[1] & 0xFF);
		byte[] data = _recv(length);
		verboseLog("TCP read", data);
		return data;
	}

	/**
	 * Performs one request/response exchange on a fresh client: optionally
	 * binds, connects, sends {@code data} and returns the message received.
	 * {@code cleanup()} is always invoked on the client afterwards.
	 *
	 * @param local local address to bind to, or {@code null} to skip binding
	 * @param remote address to connect to
	 * @param data the message to send; at most 65535 bytes
	 * @param endTime deadline passed to the client
	 * @return the payload of the message received
	 * @throws IllegalArgumentException if {@code data} is longer than 65535
	 *         bytes
	 * @throws IOException if any step of the exchange fails
	 */
	static byte[] sendrecv(SocketAddress local, SocketAddress remote,
			byte[] data, long endTime) throws IOException {
		TCPClient client = new TCPClient(endTime);
		try {
			if (local != null)
				client.bind(local);
			client.connect(remote);
			client.send(data);
			return client.recv();
		} finally {
			client.cleanup();
		}
	}

	/**
	 * Performs one request/response exchange without binding to a local
	 * address.
	 *
	 * @param addr address to connect to
	 * @param data the message to send; at most 65535 bytes
	 * @param endTime deadline passed to the client
	 * @return the payload of the message received
	 * @throws IllegalArgumentException if {@code data} is longer than 65535
	 *         bytes
	 * @throws IOException if any step of the exchange fails
	 */
	static byte[] sendrecv(SocketAddress addr, byte[] data, long endTime)
			throws IOException {
		return sendrecv(null, addr, data, endTime);
	}

}
